
#include "statsxx/machine_learning/restricted_Boltzmann_machine/RBM.hpp"

// STL
#include <cmath> // std::log()
#include <tuple> // std::tuple<>, std::make_tuple()

// jScience
#include "jScience/linalg.hpp" // Vector<>, Matrix<>, transpose()

// stats++
#include "statsxx/machine_learning/activation_functions.hpp" // activation_function::Logistic


// Deterministic reconstruction of an input vector.
//
// "Deterministic" means that the hidden activations are NOT sampled. For a sample x, the
// hidden-unit probabilities h = logistic(c + W*x) are computed first. From h, again without
// sampling, the visible units are reconstructed as v = logistic(b + W^T*h). The reconstruction
// error is then computed from x and v. For binomial input data it is the cross entropy;
// otherwise it is the squared difference. In both cases it is averaged over all visible units.
//
// x          : data (values of the visible units)
// x_binomial : whether the data is binomial; true selects the cross entropy, false the mean
//              squared error
//
// returns (v, h, error):
//     v     : reconstructed visible units
//     h     : hidden-unit probabilities (not sampled)
//     error : reconstruction error averaged over the visible units (0.0 if x is empty)
inline std::tuple<
                 Vector<double>, // visible units (reconstructed)
                 Vector<double>, // hidden units
                 double          // reconstruction error
                 > restricted_Boltzmann_machine::RBM::reconstruct(
                                                                  const Vector<double> &x,   // data
                                                                  const bool x_binomial      // whether the data is binomial
                                                                  ) const
{
    // offset added inside the logarithms so that a saturated reconstruction
    // (v(i) == 0 or v(i) == 1) never leads to log(0)
    constexpr double small = 1.0e-10;
    
    //=========================================================
    
    activation_function::Logistic logistic;
    
    Vector<double> h = logistic.f(c + W*x);
    
    // note: (for deterministic reconstruction) do NOT sample h
    
    Vector<double> v = logistic.f(b + transpose(W)*h);
    
    // number of visible units; the loops below index with this same type, which avoids the
    // signed/unsigned comparison (and possible int overflow) of an `auto i = 0` counter
    auto n = x.size();
    
    // an empty input has nothing to reconstruct: report zero error instead of 0/0 = NaN
    if(n == 0)
    {
        return std::make_tuple(v, h, 0.0);
    }
    
    // compute the error ...
    double error = 0.0;
    
    if(x_binomial)
    {
        // for binomial data, calculate the error using the cross entropy ...
        for(decltype(n) i = 0; i < n; ++i)
        {
            error -= (x(i)*std::log(v(i) + small) + (1.0 - x(i))*std::log(1.0 - v(i) + small));
        }
    }
    else
    {
        // ... else compute the error as the mean squared error (the mean is taken below)
        for(decltype(n) i = 0; i < n; ++i)
        {
            const double diff = x(i) - v(i);
            error += diff*diff;
        }
    }
    
    // average over all visible units
    error /= n;
    
    return std::make_tuple(v, h, error);
}

